Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6805df0e
Commit
6805df0e
authored
Jun 18, 2022
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into gelu
parents
1fdbe3fe
e4584d91
Changes
68
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
589 additions
and
200 deletions
+589
-200
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+120
-27
library/include/ck/library/host_tensor/host_reduction.hpp
library/include/ck/library/host_tensor/host_reduction.hpp
+19
-14
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...eference_tensor_operation/cpu/reference_conv_bwd_data.hpp
+2
-3
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
...reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp
..._instance/gpu/reduce/device_reduce_instance_blockwise.hpp
+19
-24
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp
...u/reduce/device_reduce_instance_multiblock_atomic_add.hpp
+20
-29
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp
...instance/gpu/reduce/device_reduce_instance_threadwise.hpp
+19
-24
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+12
-11
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
+6
-9
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
+6
-9
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
+6
-9
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
...xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
+6
-9
library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt
...peration_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt
+10
-0
library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
...ce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
+81
-0
library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
...ce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
+81
-0
library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
...ce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
+81
-0
library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
...ce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
+78
-0
library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
...ce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
+7
-10
library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
...ce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
+7
-10
library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
...ce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
+7
-10
No files found.
include/ck/utility/reduction_operator.hpp
View file @
6805df0e
...
...
@@ -28,6 +28,7 @@
#include "config.hpp"
#include "data_type.hpp"
#include "type.hpp"
namespace
ck
{
...
...
@@ -54,64 +55,92 @@ namespace reduce {
// accumulated index also need be
// changed.
template
<
class
T
>
struct
Add
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
static_cast
<
T
>
(
0.0
f
);
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
0.0
f
);
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
return
operation
==
InMemoryDataOperationEnum
::
AtomicAdd
||
operation
==
InMemoryDataOperationEnum
::
Set
;
};
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
a
=
a
+
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"The data type is not supported by the Add accumulator!"
);
a
=
a
+
b
;
}
};
template
<
class
T
>
struct
Mul
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
static_cast
<
T
>
(
1.0
f
);
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
1.0
f
);
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
a
=
a
*
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"The data type is not supported by the Mul accumulator!"
);
a
=
a
*
b
;
}
};
template
<
class
T
>
struct
Max
{
using
dataType
=
T
;
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
NumericLimits
<
T
>::
Lowest
();
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
// ToChange: atomic_max to be added
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
{
a
=
b
;
...
...
@@ -120,28 +149,41 @@ struct Max
}
};
template
<
class
T
>
struct
Min
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
NumericLimits
<
T
>::
Max
();
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
NumericLimits
<
T
>::
Max
();
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
// ToChange: atomic_min to be added
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Min accumulator!"
);
if
(
a
>
b
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Min accumulator!"
);
if
(
a
>
b
)
{
a
=
b
;
...
...
@@ -150,28 +192,41 @@ struct Min
}
};
template
<
class
T
>
struct
AMax
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
static_cast
<
T
>
(
0.0
f
);
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
0.0
f
);
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
// ToChange: atomic_max to be added
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the AMax accumulator!"
);
if
(
a
<
b
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the AMax accumulator!"
);
if
(
a
<
b
)
{
a
=
b
;
...
...
@@ -181,7 +236,7 @@ struct AMax
};
template
<
typename
T
>
T
GetIdentityValue
ue
ForInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
constexpr
T
GetIdentityValueForInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
T
result
=
ck
::
type_convert
<
T
>
(
0.0
f
);
...
...
@@ -191,6 +246,44 @@ T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation
return
(
result
);
};
template
<
InMemoryDataOperationEnum
Operation
,
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
{
static
constexpr
bool
value
=
false
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
AtomicAdd
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
AtomicMax
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
Set
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
bhalf_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
Add
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
};
};
// end of namespace reduce
}
// end of namespace ck
...
...
library/include/ck/library/host_tensor/host_reduction.hpp
View file @
6805df0e
...
...
@@ -174,15 +174,18 @@ struct ReductionHost
const
InDataType
*
in_data
,
float
beta
,
OutDataType
*
out_data
,
IndexDataType
*
out_indices
)
IndexDataType
*
out_indices
,
InElementwiseOperation
in_elementwise_op
,
AccElementwiseOperation
acc_elementwise_op
)
{
if
constexpr
(
OutputIndex
)
{
RunImpl_with_index
(
alpha
,
in_data
,
beta
,
out_data
,
out_indices
);
RunImpl_with_index
(
alpha
,
in_data
,
beta
,
out_data
,
out_indices
,
in_elementwise_op
,
acc_elementwise_op
);
}
else
{
RunImpl_no_index
(
alpha
,
in_data
,
beta
,
out_data
);
RunImpl_no_index
(
alpha
,
in_data
,
beta
,
out_data
,
in_elementwise_op
,
acc_elementwise_op
);
};
};
...
...
@@ -190,7 +193,9 @@ struct ReductionHost
const
InDataType
*
in_data
,
float
beta
,
OutDataType
*
out_data
,
IndexDataType
*
out_indices
)
IndexDataType
*
out_indices
,
InElementwiseOperation
in_elementwise_op
,
AccElementwiseOperation
acc_elementwise_op
)
{
using
ck
::
float_equal_one
;
using
ck
::
float_equal_zero
;
...
...
@@ -200,12 +205,10 @@ struct ReductionHost
ReduceOperation
,
AccDataType
,
IndexDataType
>
;
InElementwiseOperation
in_elementwise_op
(
divider
);
AccElementwiseOperation
acc_elementwise_op
(
divider
);
if
constexpr
(
NumInvariantDim
==
0
)
{
AccDataType
accuVal
=
ReduceOperation
::
GetIdentityValue
();
AccDataType
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>
();
IndexDataType
accuIndex
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
reduce_dim_indexes
.
size
();
i
++
)
...
...
@@ -236,7 +239,7 @@ struct ReductionHost
else
{
auto
thread_reduce_func
=
[
&
](
auto
invariant_index
)
{
AccDataType
accuVal
=
ReduceOperation
::
GetIdentityValue
();
AccDataType
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>
();
IndexDataType
accuIndex
=
0
;
auto
offset_invariant
=
...
...
@@ -297,7 +300,12 @@ struct ReductionHost
};
};
void
RunImpl_no_index
(
float
alpha
,
const
InDataType
*
in_data
,
float
beta
,
OutDataType
*
out_data
)
void
RunImpl_no_index
(
float
alpha
,
const
InDataType
*
in_data
,
float
beta
,
OutDataType
*
out_data
,
InElementwiseOperation
in_elementwise_op
,
AccElementwiseOperation
acc_elementwise_op
)
{
using
ck
::
float_equal_one
;
using
ck
::
float_equal_zero
;
...
...
@@ -306,12 +314,9 @@ struct ReductionHost
using
Accumulation
=
ck
::
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
ReduceOperation
,
AccDataType
>
;
InElementwiseOperation
in_elementwise_op
(
divider
);
AccElementwiseOperation
acc_elementwise_op
(
divider
);
if
constexpr
(
NumInvariantDim
==
0
)
{
AccDataType
accuVal
=
ReduceOperation
::
GetIdentityValue
();
AccDataType
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>
();
for
(
const
auto
&
reduce_index
:
reduce_dim_indexes
)
{
...
...
@@ -338,7 +343,7 @@ struct ReductionHost
else
{
auto
thread_reduce_func
=
[
&
](
auto
invariant_index
)
{
AccDataType
accuVal
=
ReduceOperation
::
GetIdentityValue
();
AccDataType
accuVal
=
ReduceOperation
::
template
GetIdentityValue
<
AccDataType
>
();
auto
offset_invariant
=
get_offset_from_index
<
NumInvariantDim
>
(
invariantStrides
,
invariant_index
);
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
View file @
6805df0e
...
...
@@ -106,9 +106,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
float
v_in
;
arg
.
in_element_op_
(
v_in
,
v_acc
);
arg
.
input_
(
n
,
c
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_in
);
arg
.
in_element_op_
(
v_acc
,
v_acc
);
arg
.
input_
(
n
,
c
,
wi
)
=
ck
::
type_convert
<
InDataType
>
(
v_acc
);
};
make_ParallelTensorFunctor
(
f_ncw
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
View file @
6805df0e
...
...
@@ -66,8 +66,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
arg
.
a_element_op_
(
a
,
arg
.
a_m_k_
(
m
,
k
));
arg
.
b_element_op_
(
b
,
arg
.
b_k_n_
(
k
,
n
));
arg
.
a_element_op_
(
a
,
static_cast
<
AccDataType
>
(
arg
.
a_m_k_
(
m
,
k
))
)
;
arg
.
b_element_op_
(
b
,
static_cast
<
AccDataType
>
(
arg
.
b_k_n_
(
k
,
n
))
)
;
acc
+=
a
*
b
;
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp
View file @
6805df0e
...
...
@@ -61,10 +61,10 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
>
;
#endif
template
<
typename
AccDataType
,
ReduceTensorOp
ReduceOpId
>
template
<
ReduceTensorOp
ReduceOpId
>
using
deviceReduceBlockWisePtrType
=
DeviceReducePtr
<
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
,
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
>
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
,
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
>
;
template
<
typename
InDataType
,
typename
AccDataType
,
...
...
@@ -75,14 +75,13 @@ template <typename InDataType,
bool
PropagateNan
,
bool
UseIndex
>
void
add_device_reduce_instance_blockwise
(
std
::
vector
<
deviceReduceBlockWisePtrType
<
AccDataType
,
ReduceOpId
>>&
device_op_instances
)
std
::
vector
<
deviceReduceBlockWisePtrType
<
ReduceOpId
>>&
device_op_instances
)
{
using
ReduceOperation
=
typename
reduce_binary_operator
<
AccDataType
,
ReduceOpId
>::
opType
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
constexpr
bool
Indexable
=
(
ReduceOpId
==
ReduceTensorOp
::
MIN
||
ReduceOpId
==
ReduceTensorOp
::
MAX
||
...
...
@@ -137,7 +136,7 @@ void add_device_reduce_instance_blockwise(
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<
compT,
ReduceOpId>> & device_op_instances)
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
...
...
@@ -150,21 +149,17 @@ void add_device_reduce_instance_blockwise(
Rank, \
NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp
View file @
6805df0e
...
...
@@ -61,12 +61,10 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
>
;
#endif
template
<
typename
AccDataType
,
ReduceTensorOp
ReduceOperation
>
using
deviceReduceMultiBlockAtomicAddPtrType
=
DeviceReducePtr
<
typename
reduce_unary_operator
<
AccDataType
,
ReduceOperation
,
true
,
true
>::
InElementwiseOperation
,
typename
reduce_unary_operator
<
AccDataType
,
ReduceOperation
,
true
,
true
>::
AccElementwiseOperation
>
;
template
<
ReduceTensorOp
ReduceOperation
>
using
deviceReduceMultiBlockAtomicAddPtrType
=
DeviceReducePtr
<
typename
reduce_unary_operator
<
ReduceOperation
,
true
,
true
>::
InElementwiseOperation
,
typename
reduce_unary_operator
<
ReduceOperation
,
true
,
true
>::
AccElementwiseOperation
>
;
template
<
typename
InDataType
,
typename
AccDataType
,
...
...
@@ -77,15 +75,13 @@ template <typename InDataType,
bool
PropagateNan
,
bool
UseIndex
>
void
add_device_reduce_instance_multiblock_atomic_add
(
std
::
vector
<
deviceReduceMultiBlockAtomicAddPtrType
<
AccDataType
,
ReduceOpId
>>&
device_op_instances
)
std
::
vector
<
deviceReduceMultiBlockAtomicAddPtrType
<
ReduceOpId
>>&
device_op_instances
)
{
using
ReduceOperation
=
typename
reduce_binary_operator
<
AccDataType
,
ReduceOpId
>::
opType
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
constexpr
bool
Indexable
=
(
ReduceOpId
==
ReduceTensorOp
::
MIN
||
ReduceOpId
==
ReduceTensorOp
::
MAX
||
...
...
@@ -158,8 +154,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
device_op_instances)
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
...
...
@@ -172,21 +167,17 @@ void add_device_reduce_instance_multiblock_atomic_add(
Rank, \
NumReduceDim)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp
View file @
6805df0e
...
...
@@ -47,10 +47,10 @@ using reduce_configuration_2_instances_threadwise = std::tuple<
>
;
#endif
template
<
typename
AccDataType
,
ReduceTensorOp
ReduceOpId
>
template
<
ReduceTensorOp
ReduceOpId
>
using
deviceReduceThreadWisePtrType
=
DeviceReducePtr
<
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
,
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
>
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
,
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
>
;
template
<
typename
InDataType
,
typename
AccDataType
,
...
...
@@ -61,14 +61,13 @@ template <typename InDataType,
bool
PropagateNan
,
bool
UseIndex
>
void
add_device_reduce_instance_threadwise
(
std
::
vector
<
deviceReduceThreadWisePtrType
<
AccDataType
,
ReduceOpId
>>&
device_op_instances
)
std
::
vector
<
deviceReduceThreadWisePtrType
<
ReduceOpId
>>&
device_op_instances
)
{
using
ReduceOperation
=
typename
reduce_binary_operator
<
AccDataType
,
ReduceOpId
>::
opType
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
AccDataType
,
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
constexpr
bool
Indexable
=
(
ReduceOpId
==
ReduceTensorOp
::
MIN
||
ReduceOpId
==
ReduceTensorOp
::
MAX
||
...
...
@@ -114,7 +113,7 @@ void add_device_reduce_instance_threadwise(
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceThreadWisePtrType<
compT,
ReduceOpId>> & device_op_instances)
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_THREADWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
...
...
@@ -127,21 +126,17 @@ void add_device_reduce_instance_threadwise(
Rank, \
NumReduceDim)
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_threadwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_THREADWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_threadwise<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
PropagateNan, \
UseIndex>( \
std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
#define ADD_THREADWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
6805df0e
...
...
@@ -20,7 +20,7 @@ include_directories(BEFORE
function
(
add_instance_library INSTANCE_NAME
)
message
(
"adding instance
${
INSTANCE_NAME
}
"
)
add_library
(
${
INSTANCE_NAME
}
OBJECT
${
ARGN
}
)
add_library
(
${
INSTANCE_NAME
}
OBJECT
${
ARGN
}
)
target_compile_features
(
${
INSTANCE_NAME
}
PUBLIC
)
set_target_properties
(
${
INSTANCE_NAME
}
PROPERTIES POSITION_INDEPENDENT_CODE ON
)
endfunction
(
add_instance_library INSTANCE_NAME
)
...
...
@@ -30,6 +30,7 @@ add_subdirectory(gemm_bias2d)
add_subdirectory
(
gemm_bias_relu
)
add_subdirectory
(
gemm_bias_relu_add
)
add_subdirectory
(
gemm_reduce
)
add_subdirectory
(
gemm_bias_add_reduce
)
add_subdirectory
(
batched_gemm
)
add_subdirectory
(
conv1d_fwd
)
add_subdirectory
(
conv2d_fwd
)
...
...
@@ -45,12 +46,12 @@ add_subdirectory(conv2d_bwd_weight)
add_subdirectory
(
batched_gemm_reduce
)
add_subdirectory
(
gemm_add_add_fastgelu
)
add_library
(
device_operations STATIC
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
add_library
(
device_operations STATIC
$<TARGET_OBJECTS:device_conv1d_fwd_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_add_instance>
$<TARGET_OBJECTS:device_conv2d_fwd_bias_relu_atomic_add_instance>
$<TARGET_OBJECTS:device_gemm_instance>
...
...
@@ -69,14 +70,14 @@ add_library(device_operations STATIC
add_library
(
composablekernels::device_operations ALIAS device_operations
)
set
(
DEV_OPS_INC_DIRS
set
(
DEV_OPS_INC_DIRS
${
PROJECT_SOURCE_DIR
}
/include/ck/
${
PROJECT_SOURCE_DIR
}
/library/include/ck/
${
PROJECT_SOURCE_DIR
}
/external/include/
)
target_compile_features
(
device_operations PUBLIC
)
set_target_properties
(
device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON
)
target_include_directories
(
device_operations PUBLIC
target_include_directories
(
device_operations PUBLIC
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/utility>
$<INSTALL_INTERFACE:
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck/tensor_description>
...
...
@@ -112,8 +113,8 @@ install(TARGETS device_operations
INCLUDES DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
)
install
(
DIRECTORY
${
DEV_OPS_INC_DIRS
}
DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
/ck
)
install
(
EXPORT device_operationsTargets
FILE composable_kerneldevice_operationsTargets.cmake
install
(
EXPORT device_operationsTargets
FILE composable_kerneldevice_operationsTargets.cmake
NAMESPACE composable_kernel::
DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
/cmake/composable_kernel
)
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
View file @
6805df0e
...
...
@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Identity
,
Identity
>
;
...
...
@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_in
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
View file @
6805df0e
...
...
@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Identity
,
Identity
>
;
...
...
@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_in
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
View file @
6805df0e
...
...
@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Identity
,
Identity
>
;
...
...
@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_in
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
View file @
6805df0e
...
...
@@ -21,11 +21,11 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Identity
,
Identity
>
;
...
...
@@ -59,12 +59,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_in
>
;
void
add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
...
...
library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt
0 → 100644
View file @
6805df0e
set
(
DEVICE_GEMM_REDUCE_INSTANCE_SOURCE
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
)
add_instance_library
(
device_gemm_bias_add_reduce_instance
${
DEVICE_GEMM_REDUCE_INSTANCE_SOURCE
}
)
install
(
TARGETS device_gemm_bias_add_reduce_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_gemm_bias_add_reduce_instance
)
library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
0 → 100644
View file @
6805df0e
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
DPtrsGlobal
=
ck
::
Tuple
<
F32
*
,
F32
*>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
UnaryDivide
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Div
,
Div
>
;
using
ReduceMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// c[m, n] = a[k, m] * b[k, n]
using
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
2
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
2
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
2
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
// clang-format on
>
;
void
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmBiasAddReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
0 → 100644
View file @
6805df0e
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
DPtrsGlobal
=
ck
::
Tuple
<
F32
*
,
F32
*>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
UnaryDivide
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Div
,
Div
>
;
using
ReduceMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// c[m, n] = a[k, m] * b[n, k]
using
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
2
,
8
,
32
,
32
,
2
,
4
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
2
,
2
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
2
,
8
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
2
,
8
,
32
,
32
,
1
,
2
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Col
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
// clang-format on
>
;
void
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmBiasAddReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
0 → 100644
View file @
6805df0e
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
DPtrsGlobal
=
ck
::
Tuple
<
F32
*
,
F32
*>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
UnaryDivide
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Div
,
Div
>
;
using
ReduceMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// c[m, n] = a[m, k] * b[n, k]
using
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances
=
std
::
tuple
<
// clang-format off
//##################################| ALayout| BLayout| CLayout| AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
2
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
2
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
2
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
8
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
2
,
false
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
// clang-format on
>
;
void
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmBiasAddReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
0 → 100644
View file @
6805df0e
#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp"
#include "element_wise_operation.hpp"
#include "reduction_operator.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
DPtrsGlobal
=
ck
::
Tuple
<
F32
*
,
F32
*>
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
UnaryDivide
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Div
,
Div
>
;
using
ReduceMemOp
=
ck
::
InMemoryDataOperationEnumSequence
<
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
,
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
>
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// c[m, n] = a[m, k] * b[n, k]
using
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances
=
std
::
tuple
<
// clang-format off
//##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
256
,
32
,
8
,
8
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
64
,
64
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
32
,
2
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
128
,
64
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
256
,
64
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
128
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
8
,
S
<
64
,
2
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
128
,
32
,
128
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
S
<
32
,
4
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
64
,
64
,
32
,
32
,
8
,
8
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
32
,
2
>
,
4
,
1
>
,
DeviceGemmBiasAddReduce_Xdl_CShuffle
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
,
F16
,
F16
,
F32
,
F32
,
F32
,
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ReduceOps
,
DInElementOps
,
DOutElementOps
,
ReduceMemOp
,
GemmDefault
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
,
S
<
32
,
2
>
,
4
,
1
>
// clang-format on
>
;
void
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmBiasAddReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances
{});
}
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
View file @
6805df0e
...
...
@@ -21,12 +21,12 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
Unary
Identic
<
F32
,
F32
,
true
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
Unary
Divide
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Div
,
Div
>
;
...
...
@@ -62,12 +62,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = s
>
;
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
View file @
6805df0e
...
...
@@ -21,12 +21,12 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
Unary
Identic
<
F32
,
F32
,
true
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
Unary
Divide
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Div
,
Div
>
;
...
...
@@ -62,12 +62,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = s
>
;
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances
{});
...
...
library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
View file @
6805df0e
...
...
@@ -21,12 +21,12 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReduceSum
=
ck
::
reduce
::
Add
<
F32
>
;
using
ReduceSum
=
ck
::
reduce
::
Add
;
using
ReduceOps
=
ck
::
Tuple
<
ReduceSum
,
ReduceSum
>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
Unary
Identic
<
F32
,
F32
,
true
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
Unary
Divide
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Div
,
Div
>
;
...
...
@@ -62,12 +62,9 @@ using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = s
>
;
void
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmReducePtr
<
DPtrsGlobal
,
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
std
::
vector
<
DeviceGemmReducePtr
<
PassThrough
,
PassThrough
,
PassThrough
,
DInElementOps
,
DOutElementOps
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances
{});
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment